Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model post process for zero stage3 training #17187

Merged
merged 10 commits into from
Sep 22, 2023
Merged

Conversation

pengwa
Copy link
Contributor

@pengwa pengwa commented Aug 16, 2023

Model post process for zero stage3 training

This is the last change to make single GPU/Multiple GPUs run pass.

Design details: https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

PyTorch runs with ZeROOffloadSubscriber:

  model = prepare_model(...)
  from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
  configure_ort_compatible_zero_stage3()

ORTModule runs with ZeROOffloadSubscriber:

  os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
  from onnxruntime.training.ortmodule import ORTModule
  model = ORTModule(self.model)

It will be fairly easy to debug convergence issue if both ORT and PyTorch can run the same offload path.

Motivation and Context

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Aug 16, 2023
askhade
askhade previously approved these changes Aug 22, 2023
Copy link
Contributor

@askhade askhade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Base automatically changed from pengwa/zero_offload to main August 24, 2023 16:15
@pengwa pengwa dismissed askhade’s stale review August 24, 2023 16:15

The base branch was changed.

@pengwa pengwa force-pushed the pengwa/zero_post_process branch from 86704cf to 4e59594 Compare August 24, 2023 16:48
tensor_input_dtypes: List[torch.onnx.TensorProtoDataType],
) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]:
# output = input.matmul(weight.t())
tensor_input_shapes[0] # input

Check notice

Code scanning / CodeQL

Statement has no effect

This statement has no effect.
Comment on lines +51 to +52
# if ctx.current_step >= 0:
# print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}")

Check notice

Code scanning / CodeQL

Commented-out code

This comment appears to contain commented-out code.
Copy link
Contributor

@askhade askhade left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@pengwa pengwa merged commit 6b7bce5 into main Sep 22, 2023
@pengwa pengwa deleted the pengwa/zero_post_process branch September 22, 2023 00:54
@pengwa
Copy link
Contributor Author

pengwa commented Sep 22, 2023

Thank you a lot @askhade!!!

kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
### Model post process for zero stage3 training

This is the last change to make single GPU/Multiple GPUs run pass. 

Design details:
https://microsoft.sharepoint.com/:p:/t/ONNX2/EfNfJ43necpIoPI6x5M2zvYBVbfjoPQmG4Boc_F7-tHm1w?e=ekQwA6&nav=eyJzSWQiOjMxNiwiY0lkIjoxMDE1Nzg3NDZ9

`PyTorch` runs with ZeROOffloadSubscriber:

```
  model = prepare_model(...)
  from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
  configure_ort_compatible_zero_stage3()
```

`ORTModule` runs with ZeROOffloadSubscriber:

```
  os.environ['ORTMODULE_ENABLE_ZERO_STAGE3'] = '1'
  from onnxruntime.training.ortmodule import ORTModule
  model = ORTModule(self.model)
```

It will be fairly easy to debug convergence issue if both ORT and
PyTorch can run the same offload path.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants